In [1]:
import cPickle
import gzip

from breze.learn.data import one_hot
from breze.learn.base import cast_array_to_local_type
from breze.learn.utils import tile_raster_images

import climin.stops


import climin.initialize

from breze.learn import sgvb, hvi
from matplotlib import pyplot as plt
from matplotlib import cm

import numpy as np

#import fasttsne

from IPython.html import widgets
%matplotlib inline 

import theano
theano.config.compute_test_value = 'ignore'#'raise'
Couldn't import dot_parser, loading of dot files will not be possible.
//anaconda/lib/python2.7/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.
  "`IPython.html.widgets` has moved to `ipywidgets`.", ShimWarning)
In [30]:
# Path to the gzipped, pickled MNIST dataset (train/val/test splits).
datafile = '../mnist.pkl.gz'
# Load data.

# Each split is an (inputs, labels) pair; the inputs are later used as
# Bernoulli probabilities, so pixel values are presumably in [0, 1].
with gzip.open(datafile,'rb') as f:
    train_set, val_set, test_set = cPickle.load(f)

X, Z = train_set
VX, VZ = val_set
TX, TZ = test_set

# Convert integer class labels to one-hot vectors over 10 digit classes.
Z = one_hot(Z, 10)
VZ = one_hot(VZ, 10)
TZ = one_hot(TZ, 10)

# Keep the original real-valued images around before binarization.
X_no_bin = X
VX_no_bin = VX
TX_no_bin = TX

# binarize the MNIST data
# Each pixel is drawn from a Bernoulli with the pixel intensity as its
# probability. Validation/test sets are row-tiled 5x, so every image gets
# 5 independent binarizations.
np.random.seed(0)
VX = np.random.binomial(1, np.tile(VX, (5, 1))) * 1.0
TX = np.random.binomial(1, np.tile(TX, (5, 1))) * 1.0
X = np.random.binomial(1, X) * 1.0

image_dims = 28, 28

# Keep plain numpy references (`*_np`) before casting all arrays to the
# local device type (e.g. GPU arrays) expected by the model.
X_np, Z_np, VX_np, VZ_np, TX_np, TZ_np, X_no_bin_np, VX_no_bin_np, TX_no_bin_np = X, Z, VX, VZ, TX, TZ, X_no_bin, VX_no_bin, TX_no_bin
X, Z, VX, VZ, TX, TZ, X_no_bin, VX_no_bin, TX_no_bin = [cast_array_to_local_type(i) 
                                                        for i in (X, Z, VX,VZ, TX, TZ, X_no_bin, VX_no_bin, TX_no_bin)]
print X.shape
(50000, 784)
In [31]:
# Show the first 64 (binarized) training images tiled into an 8x8 grid.
fig, ax = plt.subplots(figsize=(9, 9))

img = tile_raster_images(X[:64], image_dims, (8, 8), (1, 1))
ax.imshow(img, cmap=cm.binary)
Out[31]:
<matplotlib.image.AxesImage at 0x10812e690>
In [45]:
# Minibatch size used for both training and scoring.
batch_size = 200
# optimizer = 'rmsprop', {'step_rate': 1e-4, 'momentum': 0.95, 'decay': .95, 'offset': 1e-6}
# optimizer = 'adam', {'step_rate': .5, 'momentum': 0.9, 'decay': .95, 'offset': 1e-6}
optimizer = 'adam'

# Switch between a fast-dropout VAE variant and the plain VAE below.
fast_dropout = False

if fast_dropout:
    # Fast-dropout VAE: Gaussian latent units, Bernoulli visible units.
    class MyVAE(sgvb.FastDropoutVariationalAutoEncoder,
                sgvb.FastDropoutMlpGaussLatentVAEMixin,
                sgvb.FastDropoutMlpBernoulliVisibleVAEMixin):
        pass
    # Dropout probabilities for the input layer and the two hidden layers.
    kwargs = {
        'p_dropout_inpt': .1,
        'p_dropout_hiddens': [.2, .2],
    }
    print 'yeah'

else:
    # Plain VAE with Gaussian latents and Bernoulli visibles, using the
    # hierarchical variational inference (hvi) mixins.
    class MyVAE(sgvb.VariationalAutoEncoder,
                hvi.MlpGaussLatentVAEMixin, 
                hvi.MlpBernoulliVisibleVAEMixin, 
                ):
        pass
    kwargs = {}


# This is the number of random variables NOT the size of 
# the sufficient statistics for the random variables.
n_latents = 20
n_hidden = 200

# Two-layer MLPs on both sides; the two transfer-function lists are
# presumably recognition-net ('rectifier') and generative-net ('softplus')
# hiddens respectively -- TODO confirm against the MyVAE signature.
m = MyVAE(X.shape[1], [n_hidden, n_hidden], n_latents, [n_hidden, n_hidden], ['rectifier'] * 2, ['softplus'] * 2,
          optimizer=optimizer, batch_size=batch_size,
          **kwargs)
m.binarize_data = True
#m.exprs['loss'] += 0.001 * (m.parameters.enc_in_to_hidden ** 2).sum() / m.exprs['inpt'].shape[0]

# Initialize all parameters with small Gaussian noise (mean 0, std 1e-2).
climin.initialize.randomize_normal(m.parameters.data, 0, 1e-2)

#climin.initialize.sparsify_columns(m.parameters['enc_in_to_hidden'], 15)
#climin.initialize.sparsify_columns(m.parameters['enc_hidden_to_hidden_0'], 15)
#climin.initialize.sparsify_columns(m.parameters['dec_hidden_to_out'], 15)

#f_latent_mean = m.function(['inpt'], 'latent_mean')
#f_sample = m.function([('gen', 'layer-0-inpt')], 'output')
#f_recons = m.function(['inpt'], 'output')
In [47]:
FILENAME = 'vae_gen2_recog2_late20_hid200_softplus_np_R3.pkl'

# In[5]:
old_best_params = None
f = open(FILENAME, 'rb')
np_array = cPickle.load(f)
old_best_params = cast_array_to_local_type(np_array)
f.close()

print old_best_params.shape
print m.parameters.data.shape

m.parameters.data = old_best_params.copy()
old_best_loss = m.score(VX)
print old_best_loss
(407224,)
(407224,)
95.81616
In [21]:
#m.estimate_nll(X[:10])
# Loss of the current parameters on the test and training sets.
print m.score(TX)
print m.score(X)
131.74066
127.75316
In [22]:
print m.estimate_nll(TX, 100)
129.554060491
In [ ]:
# Train for at most `max_passes` passes over the training set.
max_passes = 250
# NOTE: Python 2 integer division; exact here because X.shape[0] (50000)
# is divisible by batch_size (200).
max_iter = max_passes * X.shape[0] / batch_size
n_report = X.shape[0] / batch_size

stop = climin.stops.AfterNIterations(max_iter)
#pause = climin.stops.ModuloNIterations(n_report)
# Report after every single iteration.
pause = climin.stops.always

# powerfit trains on X, tracking validation loss on VX; it yields an info
# dict per pause with the current training and validation losses.
for i, info in enumerate(m.powerfit((X,), (VX,), stop, pause)):
    print i, info['loss'], info['val_loss']
In [ ]:
# Dump the theano graph of the loss-gradient function for inspection.
# NOTE(review): `_f_dloss` is a private attribute of the model.
from theano.printing import debugprint
debugprint(m._f_dloss)
In [ ]:
m.parameters.data[...] = info['best_pars']
In [48]:
# Compile theano functions: `f_sample` draws visibles from the generative
# model given recognition-model latent samples; `f_recons` maps inputs
# through recognition + generation to stochastic reconstructions.
f_sample = m.function([m.recog_sample], m.vae.gen.sample())
f_recons = m.function(['inpt'], m.vae.gen.sample())
In [49]:
# Left panel: first 64 (binarized) training images.
# Right panel: their stochastic reconstructions from the model.
fig, axs = plt.subplots(1, 2, figsize=(18, 9))

#S = f_sample(cast_array_to_local_type(np.random.randn(64, m.n_latent).astype('float32')))[:, :784].astype('float32')
#img = tile_raster_images(S, image_dims, (8, 8), (1, 1))
img = tile_raster_images(X[:64], image_dims, (8, 8), (1, 1))
axs[0].imshow(img, cmap=cm.binary)

# Only the first 784 columns are taken as pixel values; presumably extra
# columns hold additional sufficient statistics -- TODO confirm.
R = f_recons(X[:64])[:, :784].astype('float32')
img = tile_raster_images(R, image_dims, (8, 8), (1, 1))

axs[1].imshow(img, cmap=cm.binary)
Out[49]:
<matplotlib.image.AxesImage at 0x10893bd10>
In [63]:
#fig, axs = plt.subplots(1, 2, figsize=(18, 9))
#img = tile_raster_images(m.parameters[m.vae.recog.mlp.layers[0].weights].T, image_dims, (10, 10), (1, 1))
#axs[0].imshow(img, cmap=cm.binary)

#img = tile_raster_images(m.parameters[m.vae.gen.mlp.layers[-1].weights], image_dims, (10, 10), (1, 1))
#axs[1].imshow(img, cmap=cm.binary)
In [51]:
f_L = m.function([m.vae.inpt], m.vae.recog.stt)
In [52]:
L = f_L(X)
In [64]:
# Scatter every latent dimension against dimension `plot_dim`, colored by
# the digit class of each example. NOTE(review): for i == plot_dim this
# plots a dimension against itself (a degenerate diagonal).
fig, axs = plt.subplots(n_latents, 1, figsize=(9, 9*n_latents))
plot_dim = 0
for i in range(n_latents):
    axs[i].scatter(L[:, i], L[:, plot_dim], c=Z[:].argmax(1), lw=0, s=10, alpha=.2)
In [ ]: